{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  CNN做MNIST分类\n",
    "refer: https://github.com/hjptriplebee/AlexNet_with_tensorflow/blob/master/alexnet.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.5/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "# Allocate GPU memory on demand instead of reserving it all up front.\n",
    "config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))\n",
    "# An InteractiveSession installs itself as the default session, which is what lets\n",
    "# later cells call Tensor.eval() / Operation.run() without passing a session.\n",
    "sess = tf.InteractiveSession(config=config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.导入数据，用 tensorflow 导入"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use the retry module or similar alternatives.\n",
      "WARNING:tensorflow:From <ipython-input-3-0b71c854ee82>:3: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please write your own downloading logic.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:219: retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use urllib or similar directly.\n",
      "Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.data to implement this functionality.\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use tf.one_hot on tensors.\n",
      "Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.5/dist-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use alternatives such as official/mnist/dataset.py from tensorflow/models.\n",
      "(10000, 10)\n",
      "(55000, 10)\n"
     ]
    }
   ],
   "source": [
    "# Load MNIST through TensorFlow's bundled tutorial helper, with one-hot labels.\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "mnist = input_data.read_data_sets('MNIST_data', one_hot=True)\n",
    "# Show the label-array shape of the test split, then of the train split.\n",
    "for split in (mnist.test, mnist.train):\n",
    "    print(split.labels.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 构建网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finished building network.\n"
     ]
    }
   ],
   "source": [
    "# Convolution layer\n",
    "def conv2d(x, filter_shape, strides_x, strides_y, padding, name):\n",
    "    \"\"\"2-D convolution followed by a bias add and a ReLU.\n",
    "\n",
    "    Args:\n",
    "        x: 4-D input tensor, [batch_size, in_height, in_width, in_channels].\n",
    "        filter_shape: list of ints,\n",
    "            [filter_height, filter_width, in_channels, out_channels].\n",
    "        strides_x: int, placed at index 1 of the 4-element strides list.\n",
    "        strides_y: int, placed at index 2 of the 4-element strides list.\n",
    "        padding: 'SAME' or 'VALID'; any other value fails the assertion.\n",
    "        name: variable scope under which the kernel 'w' and the bias 'b'\n",
    "            are created.\n",
    "    Returns:\n",
    "        4-D tensor holding relu(conv2d(x, w) + b); its last dimension is\n",
    "        out_channels.\n",
    "    \"\"\"\n",
    "    assert padding in ('SAME', 'VALID')\n",
    "    with tf.variable_scope(name):\n",
    "        kernel = tf.get_variable('w', shape=filter_shape)\n",
    "        # One bias per output channel.\n",
    "        bias = tf.get_variable('b', shape=[filter_shape[-1]])\n",
    "        conv_out = tf.nn.conv2d(\n",
    "            x, kernel, strides=[1, strides_x, strides_y, 1], padding=padding)\n",
    "        return tf.nn.relu(conv_out + bias)\n",
    "    \n",
    "\n",
    "def max_pooling(x, k_height, k_width, strides_x, strides_y, padding='SAME'):\n",
    "    \"\"\"Max-pool `x` with a k_height x k_width window.\n",
    "\n",
    "    The window and stride values fill positions 1 and 2 of the 4-element\n",
    "    ksize / strides lists; positions 0 and 3 are fixed to 1.\n",
    "    \"\"\"\n",
    "    return tf.nn.max_pool(\n",
    "        x,\n",
    "        ksize=[1, k_height, k_width, 1],\n",
    "        strides=[1, strides_x, strides_y, 1],\n",
    "        padding=padding)\n",
    "\n",
    "def dropout(x, keep_prob, name=None):\n",
    "    \"\"\"Thin wrapper around tf.nn.dropout.\n",
    "\n",
    "    Args:\n",
    "        x: input tensor.\n",
    "        keep_prob: forwarded to tf.nn.dropout as its keep probability.\n",
    "        name: optional op name.\n",
    "    Returns:\n",
    "        The tensor returned by tf.nn.dropout.\n",
    "    \"\"\"\n",
    "    dropped = tf.nn.dropout(x, keep_prob, name=name)\n",
    "    return dropped\n",
    "\n",
    "def fc(x, in_size, out_size, name, activation=None):\n",
    "    \"\"\"Fully-connected layer: x * w + b, optionally followed by an activation.\n",
    "\n",
    "    Args:\n",
    "        x: 2-D tensor, [batch_size, in_size].\n",
    "        in_size: number of input features (rows of the weight matrix).\n",
    "        out_size: number of output features (columns of the weight matrix).\n",
    "        name: variable scope under which 'w' and 'b' are created.\n",
    "        activation: None, 'relu', 'sigmoid' or 'tanh'; any other value fails\n",
    "            the assertion.\n",
    "    Returns:\n",
    "        2-D tensor, [batch_size, out_size].\n",
    "    \"\"\"\n",
    "    activations = {'relu': tf.nn.relu, 'tanh': tf.nn.tanh, 'sigmoid': tf.nn.sigmoid}\n",
    "    if activation is not None:\n",
    "        assert activation in ('relu', 'sigmoid', 'tanh'), 'Wrong activation function.'\n",
    "    with tf.variable_scope(name):\n",
    "        weights = tf.get_variable('w', shape=[in_size, out_size], dtype=tf.float32)\n",
    "        biases = tf.get_variable('b', [out_size], dtype=tf.float32)\n",
    "        linear_out = tf.nn.xw_plus_b(x, weights, biases)\n",
    "        # No matching activation means the layer stays linear.\n",
    "        act_fn = activations.get(activation)\n",
    "        if act_fn is None:\n",
    "            return linear_out\n",
    "        return act_fn(linear_out)\n",
    "\n",
    "# Placeholders for the flattened 28x28 images and their one-hot labels.\n",
    "with tf.name_scope('inputs'):\n",
    "    X_ = tf.placeholder(tf.float32, [None, 784])\n",
    "    y_ = tf.placeholder(tf.float32, [None, 10])\n",
    "\n",
    "# Reshape the flat pixels into the 4-D layout the convolution expects.\n",
    "X = tf.reshape(X_, [-1, 28, 28, 1])\n",
    "\n",
    "# Two conv + 2x2 max-pool stages.\n",
    "h_conv1 = conv2d(X, [5, 5, 1, 32], strides_x=1, strides_y=1, padding='SAME', name='conv1')\n",
    "h_pool1 = max_pooling(h_conv1, k_height=2, k_width=2, strides_x=2, strides_y=2)\n",
    "\n",
    "h_conv2 = conv2d(h_pool1, [5, 5, 32, 64], strides_x=1, strides_y=1, padding='SAME', name='conv2')\n",
    "h_pool2 = max_pooling(h_conv2, k_height=2, k_width=2, strides_x=2, strides_y=2)\n",
    "\n",
    "# Flatten the pooled feature maps into one vector per example.\n",
    "flat_size = 7 * 7 * 64\n",
    "h_pool2_flat = tf.reshape(h_pool2, [-1, flat_size])\n",
    "h_fc1 = fc(h_pool2_flat, flat_size, 1024, name='fc1', activation='relu')\n",
    "\n",
    "# Dropout keeps the shape of h_fc1 and randomly sets part of its values to zero.\n",
    "keep_prob = tf.placeholder(tf.float32)\n",
    "h_fc1_drop = dropout(h_fc1, keep_prob)\n",
    "# h_fc2 holds the raw class scores (logits); y_conv is their softmax.\n",
    "h_fc2 = fc(h_fc1_drop, 1024, 10, name='fc2')\n",
    "y_conv = tf.nn.softmax(h_fc2)\n",
    "print('Finished building network.')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.训练和评估"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<b>在测试的时候不使用 mini_batch，那么测试时会占用较多的 GPU 显存（4497M），这在 notebook 交互式编程中是不推荐的。</b>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "step 0, training accuracy 0.06\n",
      "step 1000, training accuracy 0.94\n",
      "step 2000, training accuracy 1\n",
      "step 3000, training accuracy 1\n",
      "step 4000, training accuracy 1\n",
      "step 5000, training accuracy 1\n",
      "step 6000, training accuracy 1\n",
      "step 7000, training accuracy 1\n",
      "step 8000, training accuracy 0.98\n",
      "step 9000, training accuracy 1\n",
      "test accuracy 0.9922\n"
     ]
    }
   ],
   "source": [
    "# Loss: summed softmax cross-entropy, computed from the logits (h_fc2) rather than as\n",
    "# -sum(y_ * log(softmax)), so a softmax output of exactly 0 cannot make the loss NaN.\n",
    "cross_entropy = tf.reduce_sum(\n",
    "    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=h_fc2))\n",
    "train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)\n",
    "\n",
    "# Accuracy: fraction of examples whose arg-max prediction matches the one-hot label.\n",
    "correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "# Initialise all global variables (the optimizer has been built at this point).\n",
    "sess.run(tf.global_variables_initializer())\n",
    "for i in range(10000):\n",
    "    batch = mnist.train.next_batch(50)\n",
    "    if i % 1000 == 0:\n",
    "        # Report accuracy on the current mini-batch with dropout disabled.\n",
    "        train_accuracy = accuracy.eval(feed_dict={\n",
    "            X_: batch[0], y_: batch[1], keep_prob: 1.0})\n",
    "        print(\"step %d, training accuracy %g\" % (i, train_accuracy))\n",
    "    train_step.run(feed_dict={X_: batch[0], y_: batch[1], keep_prob: 0.5})\n",
    "\n",
    "# Evaluate the whole test set in a single run.\n",
    "print(\"test accuracy %g\" % accuracy.eval(feed_dict={\n",
    "    X_: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<b>下面改成了 test 也用 mini_batch 的形式，显存只用了 529M，所以还是很成功的。</b>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From <ipython-input-6-bbf989c4bb74>:11: arg_max (from tensorflow.python.ops.gen_math_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `argmax` instead\n",
      "step 0, training acc 0.1\n",
      "step 500, training acc 0.96\n",
      "step 1000, training acc 0.86\n",
      "step 1500, training acc 0.98\n",
      "step 2000, training acc 0.98\n",
      "step 2500, training acc 0.98\n",
      "step 3000, training acc 0.98\n",
      "step 3500, training acc 1\n",
      "step 4000, training acc 1\n",
      "step 4500, training acc 0.98\n",
      "testing step 20, test_acc_sum 19.86\n",
      "testing step 40, test_acc_sum 39.58\n",
      "testing step 60, test_acc_sum 59.34\n",
      "testing step 80, test_acc_sum 79.08\n",
      "testing step 100, test_acc_sum 98.89\n",
      " test_accuracy 0.9889\n"
     ]
    }
   ],
   "source": [
    "# Aside: an earlier version of this example hit a resource-exhausted error. Why?\n",
    "# -> Training accuracy was evaluated on all 55000 training images at once, which\n",
    "#    used up the GPU memory.\n",
    "\n",
    "# 1. Loss: summed softmax cross-entropy, computed from the logits (h_fc2) rather than as\n",
    "#    -sum(y_ * log(softmax)), so a softmax output of exactly 0 cannot make the loss NaN.\n",
    "cross_entropy = tf.reduce_sum(\n",
    "    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=h_fc2))\n",
    "# 2. Optimizer: AdamOptimizer\n",
    "train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)\n",
    "\n",
    "# 3. Accuracy: compare the index of the largest predicted value with the index of the 1\n",
    "#    in the one-hot label. tf.argmax replaces the deprecated tf.arg_max.\n",
    "correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
    "\n",
    "# Evaluating the whole test set in one run can take a lot of GPU memory, so the test\n",
    "# accuracy is accumulated over small batches instead.\n",
    "test_batch_size = 100\n",
    "n_test_batches = mnist.test.num_examples // test_batch_size\n",
    "test_acc_sum = tf.Variable(0.0)\n",
    "batch_acc = tf.placeholder(tf.float32)\n",
    "new_test_acc_sum = tf.add(test_acc_sum, batch_acc)\n",
    "update = tf.assign(test_acc_sum, new_test_acc_sum)\n",
    "\n",
    "# Variables must be initialised before use. This initialises every global variable,\n",
    "# so the network weights are reset as well and training starts from scratch here.\n",
    "sess.run(tf.global_variables_initializer())\n",
    "# A single variable can also be initialised on its own, e.g.:\n",
    "# x.initializer.run()\n",
    "\n",
    "# Training\n",
    "for i in range(5000):\n",
    "    X_batch, y_batch = mnist.train.next_batch(batch_size=50)\n",
    "    if i % 500 == 0:\n",
    "        # Report accuracy on the current mini-batch with dropout disabled.\n",
    "        train_accuracy = accuracy.eval(feed_dict={X_: X_batch, y_: y_batch, keep_prob: 1.0})\n",
    "        print(\"step %d, training acc %g\" % (i, train_accuracy))\n",
    "    train_step.run(feed_dict={X_: X_batch, y_: y_batch, keep_prob: 0.5})\n",
    "\n",
    "# Test only after training has finished, one mini-batch at a time.\n",
    "for i in range(n_test_batches):\n",
    "    X_batch, y_batch = mnist.test.next_batch(batch_size=test_batch_size)\n",
    "    test_acc = accuracy.eval(feed_dict={X_: X_batch, y_: y_batch, keep_prob: 1.0})\n",
    "    update.eval(feed_dict={batch_acc: test_acc})\n",
    "    if (i + 1) % 20 == 0:\n",
    "        print(\"testing step %d, test_acc_sum %g\" % (i + 1, test_acc_sum.eval()))\n",
    "# Mean of the per-batch accuracies.\n",
    "print(\" test_accuracy %g\" % (test_acc_sum.eval() / float(n_test_batches)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
